So far in our course, we've covered everything from basic Python to the more advanced features of Python's array processing and data analysis libraries. While we have gotten into the meat of handling numbers themselves, it would be nice to have a library of tools to visualize the underlying data in a powerful but aesthetic way. The solution, which has become a massive open-source project in its own right, is matplotlib. From the matplotlib homepage:
matplotlib is a python 2D plotting library which produces publication quality figures in a variety of hardcopy formats and interactive environments across platforms. matplotlib can be used in python scripts, the python and ipython shell (ala MATLAB® or Mathematica®†), web application servers, and six graphical user interface toolkits.
matplotlib tries to make easy things easy and hard things possible. You can generate plots, histograms, power spectra, bar charts, errorcharts, scatterplots, etc, with just a few lines of code. For a sampling, see the screenshots, thumbnail gallery, and examples directory. For the power user, you have full control of line styles, font properties, axes properties, etc, via an object oriented interface or via a set of functions familiar to MATLAB users.
Let's import all of the libraries we will use in this session.
In [ ]:
from __future__ import division
import numpy as np
import matplotlib.pyplot as plt
np.random.seed(12345)
plt.rc('figure', figsize=(5, 5))
from pandas import Series, DataFrame
import pandas as pd
np.set_printoptions(precision=4)
When you use the IPython notebook, you can display plots in the output of individual cells by including the magic command:
In [ ]:
%matplotlib inline
There are two ways to think about creating and displaying plots using matplotlib. The first, and simpler, approach is the imperative, "scripting" paradigm. Modeled after the plotting functionality of MATLAB, it gives you an easy way to generate a large number of plots quickly.
The second paradigm is the object-oriented approach, which requires more initial code but offers a much higher degree of flexibility.
The main module that we use to generate plots is the pyplot submodule of matplotlib. According to established convention, we import this module as follows:
In [ ]:
import matplotlib.pyplot as plt
From now on, we will use plt to refer to the functions and attributes of the pyplot module. Here is a simple demonstration of the MATLAB approach to plotting.
In [ ]:
x = np.arange(0,10,0.1) # generates an ndarray from 0 to 9.9
plt.plot(np.sin(x))
The plot function takes an array-like object and produces a line plot of it. If you give plot a single array, it implicitly assumes that you mean to plot the coordinate pairs $(i, arr[i])$, where $i$ is the integer index.
Instead, you can pass two arrays $x$ and $y$ (of the same size), which produces a plot of the coordinates $(x[i], y[i])$.
In [ ]:
x = np.arange(0,2.0*np.pi, 0.01)
plt.plot(np.cos(x),np.sin(x))
plt.xlim([-1.1,1.1])
plt.ylim([-1.1,1.1])
It is very simple to customize the style of the plots.
In [ ]:
plt.plot?
By passing in an optional character argument, you can specify the color of the line being plotted.
Character | Color |
---|---|
'b' | blue |
'g' | green |
'r' | red |
'c' | cyan |
'm' | magenta |
'y' | yellow |
'k' | black |
'w' | white |
Alternatively, you can specify a custom (e.g. hexadecimal) color by passing a color='#123456' argument. Note that the hexadecimal value must be given as a string.
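As a small illustration of the two ways of choosing a color, the sketch below draws one line with a single-character code and one with a hexadecimal string. The particular hex value and the variable names are arbitrary choices made for this example.

```python
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import to_hex

x = np.arange(0, 10, 0.1)

# Single-character color code, passed as the format string
line_short, = plt.plot(x, np.sin(x), 'r')
# Explicit hexadecimal color, passed as a string via the color keyword
# ('#1f77b4' is an arbitrary value chosen for illustration)
line_hex, = plt.plot(x, np.cos(x), color='#1f77b4')

# to_hex converts whatever color matplotlib stored into '#rrggbb' form
print(to_hex(line_short.get_color()), to_hex(line_hex.get_color()))
```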
To customize the line style and marker shape, you can choose from a number of built-in format characters.
Character | Description | Character | Description |
---|---|---|---|
'-' | solid line style | '3' | tri_left marker |
'--' | dashed line style | '4' | tri_right marker |
'-.' | dash-dot line style | 's' | square marker |
':' | dotted line style | 'p' | pentagon marker |
'.' | point marker | '*' | star marker |
',' | pixel marker | 'h' | hexagon1 marker |
'o' | circle marker | 'H' | hexagon2 marker |
'v' | triangle_down marker | '+' | plus marker |
'^' | triangle_up marker | 'x' | x marker |
'<' | triangle_left marker | 'D' | diamond marker |
'>' | triangle_right marker | 'd' | thin_diamond marker |
'1' | tri_down marker | '_' | hline marker |
'2' | tri_up marker | | |
There are even more keyword arguments, but we won't go into the details here. Here is a simple example:
In [ ]:
x = np.arange(0,10,0.2)
plt.plot(np.sin(x), '1')
plt.plot(np.cos(x), ':')
plt.plot(np.sqrt(x), 'm', drawstyle='steps-post')
Notice that calling plot several times in one cell allows you to plot several graphs on one figure.
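The color, marker, and line-style characters from the tables above can also be combined into a single format string. The following is a minimal sketch; the two format strings are just examples and are not taken from the cells above.

```python
import numpy as np
import matplotlib.pyplot as plt

x = np.arange(0, 10, 0.5)

# 'ro--' = red ('r'), circle markers ('o'), dashed line ('--')
line_a, = plt.plot(x, np.sin(x), 'ro--')
# 'k^:' = black ('k'), triangle_up markers ('^'), dotted line (':')
line_b, = plt.plot(x, np.cos(x), 'k^:')
```

If a format string contains a marker but no line style, as in the '1' example above, only the markers are drawn.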
Now, let's customize the title, labels, legend, and ticks of the plot. In the MATLAB paradigm, we can use simple pyplot functions like title, xlabel, and ylabel, as well as call the legend function. To specify a legend string, include the optional argument label in the call to plot.
In [ ]:
x = np.random.randn(1000)
y = np.random.randn(1000)
plt.plot(x.cumsum(), 'k', label='A random walk ($X(n)$)')
plt.plot(y.cumsum(), 'r--', label='Another walk ($Y(n)$)')
# Title and labels
plt.title('This is the figure title')
plt.xlabel('The horizontal axis ($n$)')
plt.ylabel('The vertical axis')
# Tick values
plt.xticks(range(0,1001,250), rotation=30)
plt.yticks(range(-50,51, 20))
plt.legend(loc='best') # especially useful for random data
If you look carefully, you will notice that matplotlib can render $\LaTeX$ in title, axis, and legend strings. Simply enclose the $\LaTeX$ in dollar signs and matplotlib will do the rest for you. Now, there are tricky grey areas with this functionality. For example, if you want to typeset the Greek letter $\tau$ on your plots and write '$\tau$', the string will not be interpreted properly. (Why is this? See if you can figure out why.) To have the backslash taken literally, you can instead write the raw string r'$\tau$', which tells Python to leave the backslash sequence untouched before matplotlib ever sees it.
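The difference between the two spellings can be seen without plotting anything. This short sketch only inspects the strings themselves; the variable names are made up for the example.

```python
plain = '$\tau$'   # Python reads '\t' as a single tab character
raw = r'$\tau$'    # raw string: the backslash and the 't' are both kept

# repr shows exactly which characters each string contains
print(repr(plain))
print(repr(raw))
print(len(plain), len(raw))
```

Using raw strings for every label that contains $\LaTeX$ is a simple habit that avoids this class of problem.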
There are tons of ways to customize your plots further, but we'll leave this to your exploration of the matplotlib documentation.
Whereas in the MATLAB approach all plotting activity was centered around the matplotlib figure, the object-oriented approach shifts this attention to the axes, the individual plotting areas inside a figure.
When you begin plotting, you first initialize a figure and then add axes to it. Each axes object now functions as its own plotting environment, in which you can call all of the previous functions nearly identically as before.
Why go to all of this work to specify the axes objects? The immediate advantage is that you can now easily construct several axes on one figure, which is an ability I have personally found incredibly useful.
Here is a simple way to get started:
In [ ]:
fig = plt.figure(figsize=(9,3)) # instantiate a new figure object
# Add three axes aligned horizontally
ax1 = fig.add_subplot(1,3,1)
ax2 = fig.add_subplot(1,3,2)
ax3 = fig.add_subplot(1,3,3)
x = np.arange(0.0, 10.0, 0.1)
# Simple plot
ax1.plot(x, np.tan(x))
# Histogram plot
ax2.hist(np.random.randn(1000), bins=100)
# Scatter plot, parametric
ax3.scatter(np.sin(x), x*x)
ax1.set_xlabel("Axis 1")
ax2.set_xlabel("Axis 2")
ax3.set_xlabel("Axis 3")
ax1.set_title(r"$\tan(x)$", fontsize=16)
ax2.set_title("Random sampling")
ax3.set_title(r"$(\sin(x),x^2)$", fontsize=16)
Here, we construct subplots by using the add_subplot method. The first two arguments of add_subplot indicate the number of rows and columns, respectively, and the third selects the position of the new axes within that grid. Notice that for axes objects, we use set_xlabel and set_title instead of xlabel and title, but otherwise the functions work as one might expect compared to the MATLAB approach. This is generally the case for axes methods.
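To make the correspondence concrete, here is the same simple plot written both ways. This is only a sketch for comparison; the data and the labels are placeholders.

```python
import numpy as np
import matplotlib.pyplot as plt

x = np.arange(0, 10, 0.1)

# MATLAB style: pyplot functions act on the "current" axes
plt.figure()
plt.plot(x, np.sin(x))
plt.title("sin(x)")
plt.xlabel("x")
plt.xlim(0, 10)

# Object-oriented style: the same steps as methods on an explicit axes object
fig, ax = plt.subplots()
ax.plot(x, np.sin(x))
ax.set_title("sin(x)")
ax.set_xlabel("x")
ax.set_xlim(0, 10)
```

The plotting call itself (plot) keeps its name, while the functions that set a property gain a set_ prefix.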
Beyond plot, matplotlib provides a host of other plotting methods, depending on exactly what your visualization needs. On display here are the hist and scatter methods. hist takes an array and plots the distribution of its values as a bar chart. scatter is similar to plot, but it requires exactly two arrays to generate coordinate pairs.
One way to deal with a large number of axes is to think about them as iterable objects. This can dramatically reduce the amount of code required to produce sophisticated plots. For example:
In [ ]:
fig, axes = plt.subplots(3,3, figsize=(10,10), sharex=True, sharey=True)
for i in range(3):
    for j in range(3):
        x = np.random.randn(100)
        axes[i, j].hist(x, color='g', alpha=0.5)
        axes[i, j].set_title("Realization %i,%i" % (i+1,j+1))
plt.subplots_adjust(wspace=0.2,hspace=0.2)
Of note, you can specify whether two (or, in the above case, all) subplots share an x- or y-axis. This can be a nice technique to reduce the clutter around a plot. Another function that is useful for figure formatting is subplots_adjust, which allows you to specify the spacing between plots and the margins from the borders of the aggregate figure.
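Since plt.subplots returns the axes as a NumPy array, the nested loop can also be replaced by a single loop over the flattened array. The sketch below assumes a 2-by-3 grid and made-up panel titles purely for illustration.

```python
import numpy as np
import matplotlib.pyplot as plt

fig, axes = plt.subplots(2, 3, figsize=(9, 6), sharex=True, sharey=True)

# axes is a 2-D array of axes objects; .flat visits them row by row
for k, ax in enumerate(axes.flat):
    ax.hist(np.random.randn(100), color='g', alpha=0.5)
    ax.set_title("Panel %i" % (k + 1))

# left/right are fractions of the figure width;
# wspace/hspace are fractions of the average axes width/height
fig.subplots_adjust(left=0.08, right=0.98, wspace=0.1, hspace=0.25)
```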
Here we will use the MATLAB approach just for brevity of code. We have already seen plot fairly extensively, so now we will look at other matplotlib plotting functions that you might find useful.
The bar and barh methods allow you to generate bar plots, with the distinction that bar draws vertical bars while barh draws horizontal ones. Aside from orientation, both work identically (from now on we will assume bar).
bar takes an array denoting the x-coordinates of the bars (their centers by default in matplotlib 2.0 and later; earlier versions used the left edges) and an array denoting the heights of the bars. Optionally, you can add a scalar value (or an array, with one entry per bar) to denote the width of the bars. As with every plotting function, you can then specify the color, transparency (alpha), and legend label of the bars. For bar plots, you can add the additional options xerr and yerr to specify error bars in the x and y directions.
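As a brief sketch of the yerr option, the following draws five bars with vertical error bars. The heights and error sizes are made-up numbers, not data from this course.

```python
import numpy as np
import matplotlib.pyplot as plt

positions = np.arange(5)
heights = np.array([3.0, 5.0, 2.0, 4.0, 1.0])   # made-up values
errors = np.array([0.3, 0.5, 0.2, 0.4, 0.1])    # made-up error sizes

# Third positional argument is the bar width; yerr adds vertical error bars
bars = plt.bar(positions, heights, 0.6, color='c', alpha=0.6,
               yerr=errors, label="Measured")
plt.legend(loc='best')
```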
As a first example, we simply create a bar chart denoting increasing values:
In [ ]:
vals = np.arange(0,10,1)
plt.bar(vals, vals + 1, 1)
Of course, we can make this much more sophisticated. Here's a fun example that demonstrates some of the main features of bar
.
In [ ]:
width = 0.2
rows = np.arange(0,10, width)
plt.figure(figsize=(10,5))
data1 = 1.0 - 2.0/(rows +1.0) * np.sin(rows)
data2 = 2./(rows + 1.)**2 * np.abs(np.cos(rows))
plt.bar(rows, data1, width, color="y", alpha=0.7,
        label="Perceived Knowledge of C")
plt.bar(rows, data2, width, color="b", alpha=0.7, label="Happiness")
plt.legend(loc="best")
plt.xticks([0.2,2.0,4.0,7.0,9.0],
           ("None", "Pointer syntax", "Pointer arithmetic",
            "Passing pointers",
            "Function pointers"), rotation=0)
plt.xlabel("Material covered in CSC 161", fontsize=16)
plt.yticks([])
plt.title("C is a frustrating language", fontsize=20)
In [ ]:
x = np.arange(0,12,1)  # one bar per label below
plt.barh(x, np.cos(x), 0.85, align='center',
         xerr=0.05 * np.random.rand(np.size(x)), alpha=0.4)
plt.yticks(x, ("This", "Is", "A", "Bar", "Plot", "With", "Custom",
               "Tags", "For", "You", "To", "See"))
It's important to know how to build plots manually from data stored in NumPy. However, we can also use Pandas to produce high-quality matplotlib plots from existing Series and DataFrame objects with considerable ease and flexibility.
This section will walk through a variety of options you have at your disposal when building visualizations with Pandas, but it is by no means exhaustive. For more information, check out the Pandas documentation.
Given a Series object, one natural approach to plotting is with line plots. This is the default behavior of the method Series.plot, displayed below.
In [ ]:
s = Series(np.random.randn(10).cumsum(), index=np.arange(0, 100, 10))
s.plot()
This extends naturally to the DataFrame object. Since DataFrame objects already label the columns of their internal data, it is also easy to produce legends.
In [ ]:
df = DataFrame(np.random.randn(10, 4).cumsum(0),
               columns=['A', 'B', 'C', 'D'],
               index=np.arange(0, 100, 10))
df.plot()
Of course, there are many more types of visualization than line plots. In general, one can specify the type of plot a Series or DataFrame generates by changing the optional parameter kind.
Here is an example showcasing the bar and barh plots we saw earlier. Additionally, we can specify the axes on which we want the plot drawn, using the ax parameter. Pandas takes care of the formatting as well.
Plotting with Pandas plays nicely with both the MATLAB style of generation, as in the prior examples, and the object-oriented paradigm, as below.
In [ ]:
fig, axes = plt.subplots(2, 1)
data = Series(np.random.rand(16), index=list('abcdefghijklmnop'))
data.plot(kind='bar', ax=axes[0], color='k', alpha=0.7)
data.plot(kind='barh', ax=axes[1], color='k', alpha=0.7)
Again, DataFrame objects have similar functionality. Consider the following DataFrame.
In [ ]:
df = DataFrame(np.random.rand(6, 4),
               index=['one', 'two', 'three', 'four', 'five', 'six'],
               columns=pd.Index(['A', 'B', 'C', 'D'], name='Genus'))
df
We can plot the data in a bar graph just as we would for a Series object. Notice that the legend by default is not fixed to any particular location on the plot. This is the "best" parameter choice for legend location. You can hide the legend by specifying legend=False.
In [ ]:
df.plot(kind='bar')
Bar plots can also be stacked by using the stacked parameter. Notice also that when using bar or barh, Pandas takes care of aligning the data properly to its index label, using the column label for the legend.
In [ ]:
df.plot(kind='barh', stacked=True, alpha=0.5)
Now, Wes McKinney has a collection of restaurant tip data in csv format that you can download here. We can load up the file and put it directly into a DataFrame object using read_csv (more on this next week!).
In [ ]:
tips = pd.read_csv('mckinney-files/tips.csv')
tips.head() # head() shows only the first five rows
We want to cross-tabulate between the day of the week and the size of the party. In other words, we want to count how many parties of one were seated on Friday, how many parties of two, and so on, for each day of the week for which we have data (Thursday, Friday, Saturday, and Sunday). To do this we will use crosstab.
First, let's look at the columns of tips.
In [ ]:
tips.columns
Now we can build the table of day of the week against number of guests per party (1-6) using crosstab.
In [ ]:
party_counts = pd.crosstab(tips["day"], tips["size"])
In [ ]:
party_counts
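If the tips file is not at hand, the behavior of crosstab can be seen on a tiny hand-made table. The data below are invented for this sketch and only mimic the two columns used above.

```python
import pandas as pd

# A tiny made-up table with the same two columns as the tips data
toy = pd.DataFrame({
    "day":  ["Fri", "Fri", "Sat", "Sat", "Sat", "Sun"],
    "size": [2, 2, 2, 4, 4, 3],
})

# Rows: distinct days; columns: distinct party sizes; entries: counts
toy_counts = pd.crosstab(toy["day"], toy["size"])
print(toy_counts)
```

Combinations that never occur, such as a party of four on Friday in this toy table, are filled with zero.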
Now, we can proceed to the analysis. One type of plot would simply show the breakdown of guests given a day of the week.
In [ ]:
party_counts.plot(kind='barh')
Not so enlightening, because there is wide variation in the data. Let's restrict our analysis to parties with a size between 2 and 5 (inclusive).
In [ ]:
# Not many 1- and 6-person parties
party_counts = party_counts.loc[:, 2:5]  # .ix was removed from pandas; .loc slices by label, inclusive
We can "normalize" the daily data by dividing each entry by the sum of the values along its row. Notice that we use astype(float) to make sure that no integer division problems are encountered, and we specify axis=0 so that the row sums are aligned with the row (day) index: each day's counts are divided by that day's total, not by a per-party-size total.
In [ ]:
# Normalize to sum to 1
party_pcts = party_counts.div(party_counts.sum(1).astype(float), axis=0)
party_pcts
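The role of axis=0 in div is easiest to see on a table small enough to check by hand. This is a minimal sketch with made-up counts; the names counts, row_totals, and pcts are chosen for the example.

```python
import pandas as pd

# Made-up counts: rows are days, columns are party sizes
counts = pd.DataFrame([[10, 30, 10], [5, 10, 5]],
                      index=["Sat", "Sun"], columns=[2, 3, 4])

row_totals = counts.sum(axis=1)          # one total per day
pcts = counts.div(row_totals, axis=0)    # align the totals with the row index

print(pcts)
```

After the division every row sums to 1, so the rows can be compared even though the days have different numbers of parties.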
Given this new percentage data, it might make more sense to stack the bars so we can see how the distribution changes from day to day.
In [ ]:
party_pcts.plot(kind='bar', stacked=True)
plt.legend(loc='center left', bbox_to_anchor=(1.0, 0.5))
Parties seem to get much larger on weekends, while couples dominate during weekdays. Not bad for a short analysis!
Of course, one might also ask what distribution of tip percentages a server can expect to see on a given night. Pandas helps us answer this with its easy integration of histograms and density plots.
First, let's calculate tip percentages. Luckily, the tip data and the total bill data are already given, so adding a new column is simple.
In [ ]:
plt.figure()
tips['tip_pct'] = tips['tip'] / tips['total_bill']
tips['tip_pct'].hist(bins=50)
plt.title("Histogram of tip percentages")
Looks like the average is right about 15%. Not too shocking. Perhaps, instead of a histogram with bins, you want to show a smooth density estimate. To do so, we can simply choose the kde style of plot.
In [ ]:
tips['tip_pct'].plot(kind='kde')
You can actually plot histograms and density plots together. Consider the following random data samples.
In [ ]:
comp1 = np.random.normal(0, 1, size=200) # N(0, 1)
comp2 = np.random.normal(10, 2, size=200) # N(10, 4)
By having one cell plot both a histogram and a kernel density estimate plot, we can overlay the two of them together to form a solid understanding of the distribution of data in the set.
In [ ]:
values = Series(np.concatenate([comp1, comp2]))
values.hist(bins=100, alpha=0.3, color='g', density=True)  # 'normed' was removed from matplotlib
values.plot(kind='kde', style='k--')
This lets us pack a considerable amount of information into one figure without losing the essence of the visualization.
When inferring the relationship between two series of data, a scatter plot can help considerably in visualizing correlation. Consider the following economic data, which you can download as a csv file from here.
In [ ]:
macro = pd.read_csv('mckinney-files/macrodata.csv')
macro.set_index(['year','quarter']).tail()
This is a macroeconomic dataset containing the following metrics:
It's a fair bet that this is more information than we want to process at the moment, so we can define a new DataFrame containing only the columns we need in this example.
Wes McKinney then takes the data and applies a transformation (the difference of the logarithm of each series) to make the visualization easier.
In [ ]:
data = macro[['cpi', 'm1', 'tbilrate', 'unemp']]
trans_data = np.log(data).diff().dropna()
trans_data[-5:]
How does the change in the size of M1 correspond to changes in the unemployment rate? Let's find out!
In [ ]:
plt.scatter(trans_data['m1'], trans_data['unemp'], alpha=0.5)
plt.xlabel(r"Change in $\log M_1$")  # raw string keeps the backslash literal
plt.ylabel("Change in unemployment")
plt.title('Changes in log %s vs. log %s' % ('m1', 'unemp'))
It looks like increases in the money supply may be associated with increases in the unemployment rate, although it is difficult to say exactly how (we would need a model to infer anything more). Certainly, according to the data, unemployment seems to be decreasing when the money supply shrinks.
Suppose you have a new dataset and you have no idea how the various series are related. One quick approach to get a feel for the relationships, which you can later expand upon in a more thorough analysis, is the scatter matrix. Given $n$ series of data, scatter_matrix produces an $n\times n$ matrix of scatter plots corresponding to pairs of data.
The question is what to do on the main diagonal; a scatter plot of a data series with itself is quite uninteresting. Instead, the default behavior is to produce a histogram of the data series, but you can specify this to be a kde plot using the optional diagonal parameter.
In [ ]:
pd.plotting.scatter_matrix(trans_data, diagonal='kde', color='k')  # formerly pd.scatter_matrix
In this module we will start applying scipy and numpy to manipulate images. For further resources, see http://scipy-lectures.github.io/advanced/image_processing/, from which we borrow ideas.
First we need to import the libraries we need into our file. After doing this, we want to write an array into a file.
In [ ]:
%matplotlib inline
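import scipy.datasets
import matplotlib.pyplot as plt
In [ ]:
# scipy.misc.lena() and misc.imsave() have been removed from SciPy.
# ascent() is a 512x512 greyscale stand-in (SciPy >= 1.10, needs the pooch package).
l = scipy.datasets.ascent()
plt.imsave('ascent.png', l, cmap=plt.cm.gray)  # write the array to an image file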
In [ ]:
import matplotlib.pyplot as plt
plt.imshow(l)
We can also change the colormap of our image to reflect the original greyscale.
In [ ]:
plt.imshow(l, cmap=plt.cm.gray)
We can increase the contrast by changing the minimum and maximum values.
In [ ]:
plt.imshow(l, cmap=plt.cm.gray, vmin=100, vmax=200)
plt.axis('off') # Remove axes and ticks
An interesting image processing technique is drawing contour lines. We can do this using plt.contour.
In [ ]:
plt.imshow(l, cmap=plt.cm.gray,vmin=100, vmax=200)
plt.contour(l, [60, 150])
plt.axis('off')
We can inspect individual pixels for intensity variation by zooming in on a small region and passing interpolation='nearest', which turns off the smoothing between pixels.
In [ ]:
plt.imshow(l[200:220, 200:220], cmap=plt.cm.gray)
plt.imshow(l[200:220, 200:220], cmap=plt.cm.gray,
           interpolation='nearest')
Images are arrays. Consequently, we can use the array manipulations that we learned with numpy.
In [ ]:
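import scipy.datasets
import numpy as np
In [ ]:
# scipy.misc.lena() has been removed from SciPy; ascent() is a stand-in image of the same size
lena = scipy.datasets.ascent()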
lena[10:13, 20:23]
lena[100:120] = 255
lx, ly = lena.shape
X, Y = np.ogrid[0:lx, 0:ly]
mask = (X - lx/2)**2 + (Y - ly/2)**2 > lx*ly/4
lena[mask] = 0
lena[range(400), range(400)] = 255
plt.figure(figsize=(3, 3))
plt.axes([0, 0, 1, 1])
plt.imshow(lena, cmap=plt.cm.gray)
plt.axis('off')
We can easily crop, flip, and rotate the image using numpy and scipy.ndimage.
In [ ]:
from scipy import ndimage
lena = scipy.datasets.ascent()  # stand-in for the removed scipy.misc.lena()
lx, ly = lena.shape
# Cropping
crop_lena = lena[lx//4:-(lx//4), ly//4:-(ly//4)]  # slice indices must be integers
# up <-> down flip
flip_ud_lena = np.flipud(lena)
# rotation
rotate_lena = ndimage.rotate(lena, 45)
rotate_lena_noreshape = ndimage.rotate(lena, 45, reshape=False)
plt.figure(figsize=(12.5, 2.5))
plt.subplot(151)
plt.imshow(lena, cmap=plt.cm.gray)
plt.axis('off')
plt.subplot(152)
plt.imshow(crop_lena, cmap=plt.cm.gray)
plt.axis('off')
plt.subplot(153)
plt.imshow(flip_ud_lena, cmap=plt.cm.gray)
plt.axis('off')
plt.subplot(154)
plt.imshow(rotate_lena, cmap=plt.cm.gray)
plt.axis('off')
plt.subplot(155)
plt.imshow(rotate_lena_noreshape, cmap=plt.cm.gray)
plt.axis('off')
plt.subplots_adjust(wspace=0.02, hspace=0.3, top=1, bottom=0.1, left=0,
                    right=1)
We can filter images by replacing the value of each pixel by a function of the adjacent pixels. In the example below we use two different filters. The Gaussian filter sets the value of a pixel to a weighted average of the values of neighboring pixels, where nearby pixels have greater weights. The uniform filter simply takes the unweighted average of the pixels within a fixed-size window around it.
In [ ]:
lena = scipy.datasets.ascent()  # stand-in for the removed scipy.misc.lena()
blurred_lena = ndimage.gaussian_filter(lena, sigma=3)
very_blurred = ndimage.gaussian_filter(lena, sigma=5)
local_mean = ndimage.uniform_filter(lena, size=11)
plt.figure(figsize=(9, 3))
plt.subplot(131)
plt.imshow(blurred_lena, cmap=plt.cm.gray)
plt.axis('off')
plt.subplot(132)
plt.imshow(very_blurred, cmap=plt.cm.gray)
plt.axis('off')
plt.subplot(133)
plt.imshow(local_mean, cmap=plt.cm.gray)
plt.axis('off')
plt.subplots_adjust(wspace=0, hspace=0., top=0.99, bottom=0.01,
                    left=0.01, right=0.99)
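The difference between the two weightings is easiest to see on an image containing a single bright pixel, since the filtered result then shows the weights themselves. This is a sketch on a synthetic 7-by-7 array; the array and the parameter values are chosen only for illustration.

```python
import numpy as np
from scipy import ndimage

# A 7x7 image that is zero except for one bright pixel in the centre
impulse = np.zeros((7, 7))
impulse[3, 3] = 9.0

# Uniform filter: every pixel in the 3x3 window gets the same weight
flat = ndimage.uniform_filter(impulse, size=3)
# Gaussian filter: weights fall off with distance from the centre
bell = ndimage.gaussian_filter(impulse, sigma=1)

print(flat[2:5, 2:5])
print(np.round(bell[2:5, 2:5], 3))
```

The uniform filter spreads the bright pixel evenly over its window, whereas the Gaussian result peaks at the centre and decays smoothly.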
We can also sharpen a blurred image. The following shows the original image followed by a blurred image and a resharpened image.
In [ ]:
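# Stand-in for the removed scipy.misc.lena(); floats keep the subtraction below safe from integer overflow
l = scipy.datasets.ascent().astype(float)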
blurred_l = ndimage.gaussian_filter(l, 3)
filter_blurred_l = ndimage.gaussian_filter(blurred_l, 1)
alpha = 30
sharpened = blurred_l + alpha * (blurred_l - filter_blurred_l)
plt.figure(figsize=(12, 4))
plt.subplot(131)
plt.imshow(l, cmap=plt.cm.gray)
plt.axis('off')
plt.subplot(132)
plt.imshow(blurred_l, cmap=plt.cm.gray)
plt.axis('off')
plt.subplot(133)
plt.imshow(sharpened, cmap=plt.cm.gray)
plt.axis('off')
Applying the filters we have learned for blurring and sharpening images also allows us to denoise an image. However, these filters are not without problems. The Gaussian filter smooths out the noise, but it also smooths out the edges of the picture. A median filter smooths the noise while preserving the edges better than the Gaussian filter.
In [ ]:
l = scipy.datasets.ascent()  # stand-in for the removed scipy.misc.lena()
l = l[230:290, 220:320]
noisy = l + 0.4*l.std()*np.random.random(l.shape)
gauss_denoised = ndimage.gaussian_filter(noisy, 2)
med_denoised = ndimage.median_filter(noisy, 3)
plt.figure(figsize=(12,2.8))
plt.subplot(131)
plt.imshow(noisy, cmap=plt.cm.gray, vmin=40, vmax=220)
plt.axis('off')
plt.title('noisy', fontsize=20)
plt.subplot(132)
plt.imshow(gauss_denoised, cmap=plt.cm.gray, vmin=40, vmax=220)
plt.axis('off')
plt.title('Gaussian filter', fontsize=20)
plt.subplot(133)
plt.imshow(med_denoised, cmap=plt.cm.gray, vmin=40, vmax=220)
plt.axis('off')
plt.title('Median filter', fontsize=20)
plt.subplots_adjust(wspace=0.02, hspace=0.02, top=0.9, bottom=0, left=0,
                    right=1)
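The edge-preserving behavior of the median filter can also be checked on a one-dimensional "edge" with no noise at all. The sketch below uses a synthetic step signal; the array and the filter sizes are illustrative choices.

```python
import numpy as np
from scipy import ndimage

# A clean 1-D step edge: ten 0s followed by ten 1s
step = np.concatenate([np.zeros(10), np.ones(10)])

med = ndimage.median_filter(step, size=3)       # median of each 3-sample window
gauss = ndimage.gaussian_filter(step, sigma=2)  # weighted average around each sample

# Values around the edge (between index 9 and index 10)
print(med[7:13])
print(np.round(gauss[7:13], 2))
```

The median filter returns the step unchanged, while the Gaussian filter turns the jump into a gradual ramp.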
The median filter is better when working with straight edges (low-curvature images).
In [ ]:
im = np.zeros((20, 20))
im[5:-5, 5:-5] = 1
im = ndimage.distance_transform_bf(im)
im_noise = im + 0.2*np.random.randn(*im.shape)
im_med = ndimage.median_filter(im_noise, 3)
plt.figure(figsize=(12, 5))
plt.subplot(141)
plt.imshow(im, interpolation='nearest')
plt.axis('off')
plt.title('Original image', fontsize=20)
plt.subplot(142)
plt.imshow(im_noise, interpolation='nearest', vmin=0, vmax=5)
plt.axis('off')
plt.title('Noisy image', fontsize=20)
plt.subplot(143)
plt.imshow(im_med, interpolation='nearest', vmin=0, vmax=5)
plt.axis('off')
plt.title('Median filter', fontsize=20)
plt.subplot(144)
plt.imshow(np.abs(im - im_med), cmap=plt.cm.hot, vmin=0, vmax=5, interpolation='nearest')
plt.axis('off')
plt.title('Error', fontsize=20)
plt.subplots_adjust(wspace=0.02, hspace=0.02, top=0.9, bottom=0, left=0, right=1)
Lena. Once you have a noisy image, try using a median and a Gaussian filter to smooth the image.
ndimage.maximum_filter and ndimage.percentile_filter on the image of concentric squares.
scipy.signal.wiener.